{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download/Saving CIFAR-10 images in Inception format\n",
    "--------------------------------------\n",
    "\n",
    "In this script, we download the CIFAR-10 images and transform/save them in the Inception Retraining Format.  The end purpose of the files is for re-training the Google Inception tensorflow model to work on the CIFAR-10.\n",
    "\n",
    "We start by loading the necessary libraries and clearing out any current default computational graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tarfile\n",
    "import _pickle as cPickle\n",
    "import numpy as np\n",
    "import urllib.request\n",
    "import scipy.misc\n",
    "from tensorflow.python.framework import ops\n",
    "ops.reset_default_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)\n",
      "This may take a few minutes, please wait.\n"
     ]
    }
   ],
   "source": [
    "cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'\n",
    "data_dir = 'temp'\n",
    "if not os.path.isdir(data_dir):\n",
    "    os.makedirs(data_dir)\n",
    "\n",
    "# Download tar file\n",
    "target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz')\n",
    "if not os.path.isfile(target_file):\n",
    "    print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')\n",
    "    print('This may take a few minutes, please wait.')\n",
    "    filename, headers = urllib.request.urlretrieve(cifar_link, target_file)\n",
    "\n",
    "# Extract into memory\n",
    "tar = tarfile.open(target_file)\n",
    "tar.extractall(path=data_dir)\n",
    "tar.close()\n",
    "objects = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "           'dog', 'frog', 'horse', 'ship', 'truck']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we save the images in the corresponding train or test folders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create train image folders\n",
    "train_folder = 'train_dir'\n",
    "if not os.path.isdir(os.path.join(data_dir, train_folder)):\n",
    "    for i in range(10):\n",
    "        folder = os.path.join(data_dir, train_folder, objects[i])\n",
    "        os.makedirs(folder)\n",
    "# Create test image folders\n",
    "test_folder = 'validation_dir'\n",
    "if not os.path.isdir(os.path.join(data_dir, test_folder)):\n",
    "    for i in range(10):\n",
    "        folder = os.path.join(data_dir, test_folder, objects[i])\n",
    "        os.makedirs(folder)\n",
    "\n",
    "# Extract images accordingly\n",
    "data_location = os.path.join(data_dir, 'cifar-10-batches-py')\n",
    "train_names = ['data_batch_' + str(x) for x in range(1,6)]\n",
    "test_names = ['test_batch']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we create a function that will load the images from the files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_batch_from_file(file):\n",
    "    file_conn = open(file, 'rb')\n",
    "    image_dictionary = cPickle.load(file_conn, encoding='latin1')\n",
    "    file_conn.close()\n",
    "    return image_dictionary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a function that will save the images from a dictionary into correctly formatted image files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_images_from_dict(image_dict, folder='data_dir'):\n",
    "    # image_dict.keys() = 'labels', 'filenames', 'data', 'batch_label'\n",
    "    for ix, label in enumerate(image_dict['labels']):\n",
    "        folder_path = os.path.join(data_dir, folder, objects[label])\n",
    "        filename = image_dict['filenames'][ix]\n",
    "        #Transform image data\n",
    "        image_array = image_dict['data'][ix]\n",
    "        image_array.resize([3, 32, 32])\n",
    "        # Save image\n",
    "        output_location = os.path.join(folder_path, filename)\n",
    "        scipy.misc.imsave(output_location,image_array.transpose())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we sort the training and test images into the corresponding folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving images from file: data_batch_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py:11: DeprecationWarning: `imsave` is deprecated!\n",
      "`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``imageio.imwrite`` instead.\n",
      "  # This is added back by InteractiveShellApp.init_path()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving images from file: data_batch_2\n",
      "Saving images from file: data_batch_3\n",
      "Saving images from file: data_batch_4\n",
      "Saving images from file: data_batch_5\n",
      "Saving images from file: test_batch\n"
     ]
    }
   ],
   "source": [
    "# Sort train images\n",
    "for file in train_names:\n",
    "    print('Saving images from file: {}'.format(file))\n",
    "    file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)\n",
    "    image_dict = load_batch_from_file(file_location)\n",
    "    save_images_from_dict(image_dict, folder=train_folder)\n",
    "\n",
    "# Sort test images\n",
    "for file in test_names:\n",
    "    print('Saving images from file: {}'.format(file))\n",
    "    file_location = os.path.join(data_dir, 'cifar-10-batches-py', file)\n",
    "    image_dict = load_batch_from_file(file_location)\n",
    "    save_images_from_dict(image_dict, folder=test_folder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Correctly label the images/folders in a text file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing labels file, temp/cifar10_labels.txt\n"
     ]
    }
   ],
   "source": [
    "# Create labels file\n",
    "cifar_labels_file = os.path.join(data_dir,'cifar10_labels.txt')\n",
    "print('Writing labels file, {}'.format(cifar_labels_file))\n",
    "with open(cifar_labels_file, 'w') as labels_file:\n",
    "    for item in objects:\n",
    "        labels_file.write(\"{}\\n\".format(item))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have the images in the following format:\n",
    "\n",
    "```\n",
    "-train_dir\n",
    "  |--airplane\n",
    "  |--automobile\n",
    "  |--bird\n",
    "  |--cat\n",
    "  |--deer\n",
    "  |--dog\n",
    "  |--frog\n",
    "  |--horse\n",
    "  |--ship\n",
    "  |--truck\n",
    "-validation_dir\n",
    "  |--airplane\n",
    "  |--automobile\n",
    "  |--bird\n",
    "  |--cat\n",
    "  |--deer\n",
    "  |--dog\n",
    "  |--frog\n",
    "  |--horse\n",
    "  |--ship\n",
    "  |--truck\n",
    "```\n",
    "\n",
    "After this is done, we proceed with the [TensorFlow fine-tuning tutorial](https://www.tensorflow.org/tutorials/image_retraining).\n",
    "\n",
    "https://www.tensorflow.org/tutorials/image_retraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
